import torch.nn as nn
from torchmetrics import MaxMetric
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
import torch
import logging

# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)

class DenseNet(nn.Module):
    """Fully connected network built from a list of layer widths.

    Consecutive entries of ``numNodesPerLayerList`` define one ``nn.Linear``
    layer each. A ``nn.ReLU`` follows every linear layer except the last, so
    the network output is the raw output of the final linear layer. The
    training loss is ``nn.MSELoss``.
    """

    def __init__(self, numNodesPerLayerList):
        """Build the layer stack.

        Args:
            numNodesPerLayerList: Layer widths, input size first and output
                size last. Needs at least two entries.

        Raises:
            ValueError: If fewer than two widths are given.
        """
        super().__init__()
        sizes = list(numNodesPerLayerList)
        if len(sizes) < 2:
            raise ValueError(
                "numNodesPerLayerList needs at least an input and an output "
                f"size, got {len(sizes)} entries"
            )

        layerList = []
        # Pair each width with its successor instead of indexing with a loop
        # variable after the loop; that approach broke for two-entry lists.
        for in_features, out_features in zip(sizes[:-1], sizes[1:]):
            layerList.append(nn.Linear(in_features, out_features))
            layerList.append(nn.ReLU())
        # No activation after the output layer.
        layerList.pop()

        self.layers = nn.Sequential(*layerList)
        self.loss = nn.MSELoss()

    def forward(self, x):
        """Return the output of the layer stack applied to ``x``."""
        return self.layers(x)

    def step(self, batch):
        """Run a forward pass and compute the MSE loss for one batch.

        Args:
            batch: A pair ``(x, y)`` of inputs and targets.

        Returns:
            Tuple ``(loss, outputs, y)`` where ``outputs`` is ``self(x)``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss, logits, y

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and report it.

        Args:
            batch: A pair ``(x, y)`` of inputs and targets.
            batch_idx: Index of the batch; not used here.

        Returns:
            Dict with keys ``"loss"``, ``"preds"`` and ``"targets"``.
        """
        loss, preds, targets = self.step(batch)

        # A plain ``nn.Module`` has no ``log`` method, so calling it
        # unconditionally raises AttributeError. Use it only when a subclass
        # or mixin provides one (the keyword arguments look Lightning-style --
        # TODO confirm the intended base class); otherwise fall back to the
        # standard logger.
        metric_logger = getattr(self, "log", None)
        if callable(metric_logger):
            metric_logger(
                "teacher_selection_model/loss",
                loss,
                on_step=False,
                on_epoch=True,
                prog_bar=False,
            )
        else:
            logging.getLogger(__name__).debug(
                "teacher_selection_model/loss: %s", loss
            )

        # The loss must be part of the return value so the caller can
        # backpropagate through it.
        return {"loss": loss, "preds": preds, "targets": targets}

if __name__ == '__main__':
    # Smoke test: build a small network, run one forward pass, then take a
    # few optimizer steps on a random regression batch.
    model = DenseNet([10, 15, 20, 30])
    print(model(torch.rand((1, 10))))
    print(model)

    loss_function = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # One fixed random batch. The target width (30) must match the model's
    # output size; the original loop referenced an undefined ``targets``.
    data = torch.rand((32, 10))
    targets = torch.rand((32, 30))

    current_loss = 0.0
    num_epochs = 10
    for _ in range(num_epochs):
        # Zero the gradients
        optimizer.zero_grad()

        # Perform forward pass
        outputs = model(data)

        # Compute loss
        loss = loss_function(outputs, targets)

        # Perform backward pass
        loss.backward()

        # Perform optimization
        optimizer.step()
        current_loss += loss.item()

    print(f"mean loss over {num_epochs} epochs: {current_loss / num_epochs:.6f}")